#!/usr/bin/env python3
# Copyright (C) 2024 Checkmk GmbH - License: GNU General Public License v2
# This file is part of Checkmk (https://checkmk.com). It is subject to the terms and
# conditions defined in the file COPYING, which is part of this source code package.

# <<<splunk_license_state>>>
# Splunk_Enterprise_Splunk_Analytics_for_Hadoop_Download_Trial 5 30 524288000 1561977130 VALID
# Splunk_Forwarder 5 30 1048576 2147483647 VALID
# Splunk_Free 3 30 524288000 2147483647 VALID

import time
from collections.abc import Sequence
from typing import Literal, NewType, Self, TypedDict

import pydantic

from cmk.agent_based.v1.type_defs import StringTable
from cmk.agent_based.v2 import (
    AgentSection,
    check_levels,
    CheckPlugin,
    CheckResult,
    DiscoveryResult,
    FixedLevelsT,
    render,
    Result,
    Service,
    State,
)

# 86_400 is the number of seconds in one day (24 * 60 * 60).
DEFAULT_WARNING_EXPIRATION_TIME = 14 * 86_400
"""Default lower warning threshold for the remaining license lifetime: 14 days, in seconds."""
DEFAULT_CRITICAL_EXPIRATION_TIME = 7 * 86_400
"""Default lower critical threshold for the remaining license lifetime: 7 days, in seconds."""

# Distinct static type for license labels; at runtime NewType hands back the plain str.
LicenseStateLabel = NewType("LicenseStateLabel", str)
"""The label for the license state that serves as the discovery item."""

# Closed set of status strings; used as the type of the LicenseState.status field.
LicenseStatus = Literal["VALID", "EXPIRED"]
"""Valid license status values defined by splunk."""


class LicenseState(pydantic.BaseModel):
    """Immutable, validated record of a single splunk license.

    The declaration order of the fields is significant: ``from_string_table_item``
    pairs the columns of an agent line with the fields in exactly this order.
    """

    model_config = pydantic.ConfigDict(frozen=True)

    label: LicenseStateLabel
    """License name."""
    max_violations: int = pydantic.Field(ge=0)
    """Highest number of license violations tolerated within the window period."""
    window: int = pydantic.Field(ge=0)
    """Length of the window period, in days."""
    quota: int = pydantic.Field(ge=0)
    """Data quota, in bytes."""
    expiration: int = pydantic.Field(ge=0)
    """Point in time at which the license expires, as a timestamp in seconds."""
    status: LicenseStatus
    """License status."""

    @classmethod
    def from_string_table_item(cls, table: Sequence[str]) -> Self:
        """Create a validated instance from one line of the agent section.

        Columns are matched to the model fields by position. ``zip`` stops at the
        shorter of its inputs, so surplus columns are dropped and missing columns
        leave their fields out of the payload that is handed to validation.
        """
        columns_by_field = zip(cls.__pydantic_fields__, table)
        return cls.model_validate_strings(dict(columns_by_field))

    def calculate_time_to_expiration(self, now: float) -> float:
        """Return the seconds left until expiration, measured from ``now``.

        The result is zero or negative once ``now`` has reached the expiration time.
        """
        remaining = self.expiration - now
        return remaining


# Keyed by license label so the check can look up a service item directly.
type LicenseStateSection = dict[LicenseStateLabel, LicenseState]
"""The output generated by the parsing function."""


def parse_splunk_license_state(string_table: StringTable) -> LicenseStateSection:
    """Build the license state section from the raw agent lines.

    Lines that fail validation are left out. When a label occurs more than once,
    the last valid line wins.
    """
    parsed: LicenseStateSection = {}

    for line in string_table:
        try:
            state = LicenseState.from_string_table_item(line)
        except pydantic.ValidationError:
            # Best effort: one malformed line must not discard the valid ones.
            continue
        parsed[state.label] = state

    return parsed


def discover_splunk_license_state(section: LicenseStateSection) -> DiscoveryResult:
    """Yield one service per license label found in the parsed section."""
    for label in section:
        yield Service(item=label)


# The configured integer is converted with State(...) inside check().
type StateValue = Literal[0, 1, 2, 3]
"""A valid integer code related to the result state."""

# Same shape as the plugin's default parameters: ("fixed", (warn, crit)).
type IntLevels = FixedLevelsT[int]
"""Fixed warn and critical integer threshold."""


class CheckParams(TypedDict):
    """Parameters passed to plugin via ruleset (see defaults).

    Keys:
        state: Result state code reported when the license status is "EXPIRED".
        expiration_time: Lower warn/crit levels applied to the time remaining
            until the license expires, in seconds.
    """

    state: StateValue
    expiration_time: IntLevels


def check(data: LicenseState, params: CheckParams, *, now: float) -> CheckResult:
    """Yield the results for one license state.

    The results come in this order:

    1. The license status together with the expiration date. An expired license
       is reported with the state configured in ``params["state"]``, a valid one
       as OK.
    2. Only while the license has time left: the remaining time, checked against
       the lower levels in ``params["expiration_time"]``.
    3. Two informational OK results for the violation limits and the quota.

    ``now`` is the reference time in seconds and is keyword-only so that callers
    can pin it explicitly.
    """
    if data.status == "EXPIRED":
        expired_state = State(params["state"])
        yield Result(
            state=expired_state,
            summary=f"Status: {data.status} on {render.datetime(data.expiration)}",
        )
    elif data.status == "VALID":
        yield Result(
            state=State.OK,
            summary=f"Status: {data.status} until {render.datetime(data.expiration)}",
        )

    remaining = data.calculate_time_to_expiration(now)
    if remaining > 0:
        # Lower levels: less remaining time is worse.
        yield from check_levels(
            remaining,
            levels_lower=params["expiration_time"],
            render_func=render.timespan,
            label="Expiration time",
        )

    violations_summary = (
        f"Max violations: {data.max_violations} within window period of {data.window} days"
    )
    yield Result(state=State.OK, summary=violations_summary)

    quota_summary = f"Quota: {render.bytes(data.quota)}"
    yield Result(state=State.OK, summary=quota_summary)


def check_splunk_license_state(
    item: LicenseStateLabel, params: CheckParams, section: LicenseStateSection
) -> CheckResult:
    """Run the license check for ``item`` against the current wall-clock time.

    Yields nothing when the section holds no entry for ``item``.
    """
    license_state = section.get(item)
    if license_state:
        yield from check(license_state, params, now=time.time())


# Registers parse_splunk_license_state as the parser for the "splunk_license_state" section.
agent_section_splunk_license_state = AgentSection(
    name="splunk_license_state",
    parse_function=parse_splunk_license_state,
)

# Registers the check plugin, wiring the discovery and check functions together
# with the ruleset name and the default parameters.
check_plugin_splunk_license_state = CheckPlugin(
    name="splunk_license_state",
    service_name="Splunk License %s",
    discovery_function=discover_splunk_license_state,
    check_function=check_splunk_license_state,
    check_ruleset_name="splunk_license_state",
    check_default_parameters=CheckParams(
        # By default an expired license is reported as CRIT.
        state=State.CRIT.value,
        # Default lower levels on the remaining lifetime: warn at 14 days, crit at 7 days.
        expiration_time=(
            "fixed",
            (
                DEFAULT_WARNING_EXPIRATION_TIME,
                DEFAULT_CRITICAL_EXPIRATION_TIME,
            ),
        ),
    ),
)
